{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "d968de3f",
   "metadata": {},
   "source": [
    "# Automatic Mixed-Precision (AMP)\n",
    "\n",
    "This notebook shows a working code example of how to use AIMET to perform Auto Mixed Precision (AMP). AMP is a technique where given a quantized accuracy target, AIMET finds bit-precision per-layer to meet that accuracy target while trying to optimize the model for inference speed.\n",
    "\n",
    "As an example, say a particular model is not meeting a desired accuracy target when run in INT8. The Auto Mixed Precision feature will find a minimal set of layers that need to run on say INT16 to get to the desired accuracy. It should be noted that choosing higher precision for some layers necessarily involves a trade-off: lower inferences/sec for higher accuracy and vice-versa.\n",
    "\n",
    "Alternatively, the AMP feature can be used to generate a pareto curve (accuracy vs. bit-ops) that can guide the user to decide the right operating point for this tradeoff.\n",
    "\n",
    "This notebook specifically shows working code example for the above. \n",
    "\n",
    "#### Overall flow\n",
    "This notebook covers the following\n",
    "1. Instantiate the example evaluation method\n",
    "2. Load the FP32 model and evaluate the model to find the baseline FP32 accuracy\n",
    "3. Create a quantization simulation model (with fake quantization ops inserted)\n",
    "4. Run AMP algorithm on the quantized model \n",
    "    1. Using the Regular AMP method\n",
    "    2. Using the Fast AMP Method (AMP 2.0)\n",
    "\n",
    "#### What this notebook is not\n",
    "* This notebook is not designed to show state-of-the-art AMP results. For example, it uses a relatively quantization-friendly model like ResNet50. Also, some optimization parameters like number of samples for evaluation are deliberately chosen to have the notebook execute more quickly."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f3ffda97",
   "metadata": {},
   "source": [
    "---\n",
    "## Dataset\n",
    "\n",
    "This notebook relies on the ImageNet dataset for the task of image classification. If you already have a version of the dataset readily available, please use that. Else, please download the dataset from appropriate location (e.g. https://image-net.org/challenges/LSVRC/2012/index.php#).\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0cfc8879",
   "metadata": {},
   "outputs": [],
   "source": [
    "DATASET_DIR = '/path/to/dataset'        # Please replace this with a real directory\n",
    "BATCH_SIZE = 32"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3af3b030",
   "metadata": {},
   "source": [
    "We disable logs at the INFO level. We set verbosity to the level as displayed (ERORR), so TensorFlow will display all messages that have the label ERROR (or more critical)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f261881d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "os.environ['TF_CPP_MIN_LOG_LEVEL'] = \"2\"\n",
    "import tensorflow as tf\n",
    "tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c23fb65c",
   "metadata": {},
   "source": [
    "---\n",
    "## 1. Instantiate the example evaluation method\n",
    "\n",
    "The following is an example evlauation method which we will used to evaluate the accuracy for the model as well to perform a forward pass on the model. AIMET needs forward pass for calculating the range of values at activations of each layer.Below is an example function which we will use for both evaluation and the forward pass."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "42d73327",
   "metadata": {},
   "outputs": [],
   "source": [
    "from tensorflow.keras.applications.resnet import preprocess_input, decode_predictions\n",
    "\n",
    "def center_crop(image):\n",
    "    \n",
    "    img_height = 256\n",
    "    img_width = 256\n",
    "    crop_length = 224\n",
    "\n",
    "    start_x = (img_height - crop_length) // 2\n",
    "    start_y = (img_width - crop_length) // 2\n",
    "    cropped_image=image[: ,  start_x:(img_width - start_x), start_y:(img_height - start_y), :]\n",
    "\n",
    "    return cropped_image\n",
    "\n",
    "\n",
    "def get_eval_func(dataset_dir, batch_size, num_iterations=50000, debug=False, get_top5_acc=False):\n",
    "        \n",
    "    def func_wrapper(model, iterations=None, use_cuda=True):\n",
    "\n",
    "        validation_ds = tf.keras.preprocessing.image_dataset_from_directory(\n",
    "            directory=dataset_dir,\n",
    "            labels='inferred',\n",
    "            label_mode='categorical',\n",
    "            batch_size=batch_size,\n",
    "            shuffle = False,\n",
    "            image_size=(256, 256))\n",
    "        # If no iterations specified, set to full validation set\n",
    "        if not iterations:\n",
    "            iterations = num_iterations\n",
    "        else:\n",
    "            iterations = iterations * batch_size\n",
    "        top1 = 0\n",
    "        top5 = 0\n",
    "        total = 0\n",
    "        for (img,label) in validation_ds:\n",
    "            img = center_crop(img)\n",
    "            x = preprocess_input(img)\n",
    "            preds = model.predict(x,batch_size = batch_size)\n",
    "            label = np.where(label)[1]\n",
    "            label = [validation_ds.class_names[int(i)] for i in label]\n",
    "            cnt = sum([1 for a, b in zip(label, decode_predictions(preds, top=1)) if str(a) == b[0][0]])\n",
    "            top1 += cnt\n",
    "            cnt = sum([1 for a, b in zip(label, decode_predictions(preds, top=5)) if str(a) in [i[0] for i in b]])\n",
    "            top5 += cnt\n",
    "            total += len(label)\n",
    "            if total >= iterations:\n",
    "                break\n",
    "        if get_top5_acc == True:\n",
    "            return top1/total, top5/total\n",
    "        else:\n",
    "            return top1/total\n",
    "    return func_wrapper\n",
    "\n",
    "\n",
    "\n",
    "# Instantiate the evaluation function\n",
    "eval_func = get_eval_func(DATASET_DIR, BATCH_SIZE)"
   ]
  },
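  {
   "cell_type": "markdown",
   "id": "a7c1e0d1",
   "metadata": {},
   "source": [
    "The top-1 and top-5 bookkeeping in the evaluation function above can be hard to follow because it goes through `decode_predictions`. As a minimal sketch, assuming integer class labels and a hypothetical helper named `top_k_accuracy` (not part of AIMET or Keras), the same kind of metric can be computed with NumPy alone:\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def top_k_accuracy(probs, labels, k=1):\n",
    "    # probs: (num_samples, num_classes) scores; labels: (num_samples,) integer class indices\n",
    "    # Indices of the k highest-scoring classes per sample (order within the top k is irrelevant)\n",
    "    top_k = np.argsort(probs, axis=1)[:, -k:]\n",
    "    hits = (top_k == labels[:, None]).any(axis=1)\n",
    "    return float(hits.mean())\n",
    "\n",
    "# Toy batch of 3 samples and 3 classes, for illustration only\n",
    "probs = np.array([[0.1, 0.7, 0.2],\n",
    "                  [0.5, 0.3, 0.2],\n",
    "                  [0.2, 0.3, 0.5]])\n",
    "labels = np.array([1, 1, 2])\n",
    "\n",
    "print(top_k_accuracy(probs, labels, k=1))\n",
    "print(top_k_accuracy(probs, labels, k=2))\n",
    "```\n",
    "\n",
    "As in the evaluation function above, the result is the number of correct predictions divided by the number of samples seen."
   ]
  },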
  {
   "cell_type": "markdown",
   "id": "cdbf83fa",
   "metadata": {},
   "source": [
    "---\n",
    "## 2. Load the FP32 model and evaluate the model to find the baseline FP32 accuracy\n",
    "\n",
    "For this example notebook, we are going to load a pretrained ResNet50 model from keras . Similarly, you can load any pretrained tensorflow model instead."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b20fb56b",
   "metadata": {},
   "outputs": [],
   "source": [
    "from tensorflow.keras.applications.resnet import ResNet50\n",
    "from aimet_tensorflow.keras.batch_norm_fold import fold_all_batch_norms\n",
    "\n",
    "def get_model():\n",
    "    model = ResNet50(\n",
    "        include_top=True,\n",
    "        weights=\"imagenet\",\n",
    "        input_tensor=None,\n",
    "        input_shape=None,\n",
    "        pooling=None,\n",
    "        classes=1000)\n",
    "\n",
    "    return model\n",
    "\n",
    "model = get_model()\n",
    "# We will perform the batch norm folding on the loaded model.\n",
    "\n",
    "_ = fold_all_batch_norms(model)\n",
    "\n",
    "# calculate the FP32 model acccuracy\n",
    "\n",
    "fp32_acccuracy = eval_func(model, None)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "35720911",
   "metadata": {},
   "source": [
    "---\n",
    "## 3.Create a quantization simulation model (with fake quantization ops inserted)\n",
    "\n",
    "\n",
    "Now we use AIMET to create a QuantizationSimModel. This basically means that AIMET will insert fake quantization ops in the model graph and will configure them.\n",
    "\n",
    "A few of the parameters are explained here\n",
    "- **quant_scheme**: We set this to \"QuantScheme.post_training_tf_enhanced\"\n",
    "    - Supported options are 'tf_enhanced' or 'tf' or using Quant Scheme Enum QuantScheme.post_training_tf or QuantScheme.post_training_tf_enhanced\n",
    "- **default_output_bw**: Setting this to 8, essentially means that we are asking AIMET to perform all activation quantizations in the model using integer 8-bit precision\n",
    "- **default_param_bw**: Setting this to 8, essentially means that we are asking AIMET to perform all parameter quantizations in the model using integer 8-bit precision\n",
    "\n",
    "There are other parameters that are set to default values in this example. Please check the AIMET API documentation of QuantizationSimModel to see reference documentation for all the parameters."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "26ccd7df",
   "metadata": {},
   "outputs": [],
   "source": [
    "from aimet_common.defs import QuantScheme\n",
    "from aimet_tensorflow.keras.quantsim import QuantizationSimModel\n",
    "\n",
    "sim = QuantizationSimModel(\n",
    "        model=model,\n",
    "        quant_scheme=QuantScheme.post_training_tf_enhanced,\n",
    "        rounding_mode=\"nearest\",\n",
    "        default_output_bw=8,\n",
    "        default_param_bw=8\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "854aff28",
   "metadata": {},
   "source": [
    "## Compute Encodings\n",
    "Even though AIMET has added 'quantizer' nodes to the model graph but the model is not ready to be used yet. Before we can use the sim model for inference or training, we need to find appropriate scale/offset quantization parameters for each 'quantizer' node. For activation quantization nodes, we need to pass unlabeled data samples through the model to collect range statistics which will then let AIMET calculate appropriate scale/offset quantization parameters. This process is sometimes referred to as calibration. AIMET simply refers to it as 'computing encodings'.\n",
    "\n",
    "The following shows an example of a routine that passes unlabeled samples through the model for computing encodings. This routine can be written in many different ways, this is just an example."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6984cc14",
   "metadata": {},
   "outputs": [],
   "source": [
    "sim.compute_encodings(eval_func, forward_pass_callback_args=500)"
   ]
  },
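  {
   "cell_type": "markdown",
   "id": "a7c1e0d2",
   "metadata": {},
   "source": [
    "To make the scale/offset parameters described above concrete, here is a minimal sketch of asymmetric min/max quantization. It is a simplified illustration and not AIMET's `tf_enhanced` scheme; the helper names `compute_encoding` and `fake_quantize` are hypothetical, and the observed range is assumed to be non-degenerate (`x_max > x_min`).\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def compute_encoding(x_min, x_max, bitwidth=8):\n",
    "    # Extend the range to include zero so that zero lies on the quantization grid\n",
    "    x_min, x_max = min(x_min, 0.0), max(x_max, 0.0)\n",
    "    num_steps = 2 ** bitwidth - 1\n",
    "    scale = (x_max - x_min) / num_steps\n",
    "    offset = int(np.round(x_min / scale))  # integer grid index of x_min (<= 0)\n",
    "    return scale, offset\n",
    "\n",
    "def fake_quantize(x, scale, offset, bitwidth=8):\n",
    "    # Quantize to the integer grid, clamp to the representable range, map back to float\n",
    "    num_steps = 2 ** bitwidth - 1\n",
    "    q = np.clip(np.round(x / scale) - offset, 0, num_steps)\n",
    "    return (q + offset) * scale\n",
    "\n",
    "# Toy activations spanning the observed range [-0.5, 1.5]\n",
    "x = np.linspace(-0.5, 1.5, 101)\n",
    "scale, offset = compute_encoding(x.min(), x.max(), bitwidth=8)\n",
    "x_q = fake_quantize(x, scale, offset, bitwidth=8)\n",
    "max_error = np.max(np.abs(x - x_q))\n",
    "```\n",
    "\n",
    "For values inside the observed range, the rounding error of this scheme is at most half a quantization step, which is why a representative range matters: a range that is too wide wastes resolution, and one that is too narrow clips values."
   ]
  },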
  {
   "cell_type": "markdown",
   "id": "0b78c729",
   "metadata": {},
   "source": [
    "---\n",
    "## 4. Run AMP algorithm on the quantized model\n",
    "\n",
    "AMP algorithm runs in 3 phases (phase-3 is optional). Phase-1 comprises of calculating the sensitivity for each layer. Phase-2 comprises of greedily selecting which layers should have what bitwidth based on options provided by the user. Phase-3 derives a set of mixed-precision solutions having less bitwidth convert op overhead compared to original phase-2 solution. For phase 1 and phase 2 we require to pass data through the model. \n",
    "\n",
    "So we create a routine to pass unlabeled data samples through the model. This should be fairly simple - use the existing train or validation data loader to extract some samples and pass them to the model. We don't need to compute any loss metric etc. So we can just ignore the model output for this purpose. A few pointers regarding the data samples\n",
    "- In practice, we need a very small percentage of the overall data samples for computing encodings. For example, the training dataset for ImageNet has 1M samples. For phase 1 we only need 500 or 1000 samples. For phase 2 it is recommended to use all of validation data. This is done to speed up AMP execution. Therefore, we define 2 separate functions for phase 1 and phase 2. \n",
    "- For phase 2, if a large-enough subset of the samples provide a meaningful accuracy score, we can use the subset of samples to speed up the AMP algorithm\n",
    "- It may be beneficial if the samples used for forward pass are well distributed. It's not necessary that all classes need to be covered etc. since we are only looking at the range of values at every layer activation. However, we definitely want to avoid an extreme scenario like all 'dark' or 'light' samples are used - e.g. only using pictures captured at night might not give ideal results.\n",
    "We have two method for doing AMP in Keras. One can opt for any one of the methods.\n",
    "    1. Regular AMP \n",
    "    2. Fast AMP (AMP 2.0)"
   ]
  },
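  {
   "cell_type": "markdown",
   "id": "a7c1e0d3",
   "metadata": {},
   "source": [
    "The greedy idea behind phase 2 can be sketched in a few lines. This is a toy illustration under stated assumptions, not AIMET's implementation: the function `greedy_bitwidth_selection`, the layer names, the sensitivity values, and the additive accuracy model in `toy_eval` are all hypothetical.\n",
    "\n",
    "```python\n",
    "def greedy_bitwidth_selection(sensitivities, eval_fn, target_accuracy, high=16, low=8):\n",
    "    # sensitivities: {layer name: accuracy drop measured when only that layer runs at `low`}\n",
    "    # eval_fn: maps {layer name: bitwidth} to an accuracy score\n",
    "    assignment = {name: high for name in sensitivities}\n",
    "    # Visit the least sensitive layers first\n",
    "    for name in sorted(sensitivities, key=sensitivities.get):\n",
    "        assignment[name] = low\n",
    "        if eval_fn(assignment) < target_accuracy:\n",
    "            assignment[name] = high  # undo the move that broke the target and stop\n",
    "            break\n",
    "    return assignment\n",
    "\n",
    "# Hypothetical phase 1 results, for illustration only\n",
    "sensitivities = {'conv1': 0.004, 'conv2': 0.001, 'fc': 0.010}\n",
    "\n",
    "def toy_eval(assignment):\n",
    "    # Assumed accuracy model: per-layer drops simply add up\n",
    "    drop = sum(sensitivities[name] for name, bw in assignment.items() if bw == 8)\n",
    "    return 0.76 - drop\n",
    "\n",
    "result = greedy_bitwidth_selection(sensitivities, toy_eval, target_accuracy=0.75)\n",
    "print(result)\n",
    "```\n",
    "\n",
    "In this toy run, the two least sensitive layers move to 8 bits and the most sensitive one stays at 16 bits. Each step needs an accuracy evaluation, which is why the phase 2 callback and the amount of data it uses dominate the running time."
   ]
  },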
  {
   "cell_type": "markdown",
   "id": "321952f9",
   "metadata": {},
   "source": [
    "## Parameters for AMP algorithm\n",
    "\n",
    "A few of the parameters required for AMP are explained below\n",
    "\n",
    "- **candidates** : It is a list of tuples for all possible bitwidth values for activations and parameters. Suppose the possible combinations are-((Activation bitwidth - 8, Activation data type - int), (Parameter bitwidth - 16, parameter data type - int)) ((Activation bitwidth - 16, Activation data type - float), (Parameter bitwidth - 16, parameter data type - float)) candidates will be [((8, QuantizationDataType.int), (16, QuantizationDataType.int)), ((16, QuantizationDataType.float), (16, QuantizationDataType.float))]\n",
    "- **allowed_accuracy_drop** : Maximum allowed drop in accuracy from FP32 baseline. The pareto front curve is plotted only till the point where the allowable accuracy drop is met. To get a complete plot for picking points on the curve, the user can set the allowable accuracy drop to None.\n",
    "- **results_dir** : Path to save results and cache intermediate results\n",
    "- **clean_start** : If true, any cached information from previous runs will be deleted prior to starting the mixed-precision analysis. If false, prior cached information will be used if applicable. Note it is the user's responsibility to set this flag to true if anything in the model or quantization parameters changes compared to the previous run.\n"
   ]
  },
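  {
   "cell_type": "markdown",
   "id": "a7c1e0d4",
   "metadata": {},
   "source": [
    "To get a feel for what a candidate costs, the sketch below compares the bit-ops of two bitwidth assignments. It assumes a simple cost model in which each multiply-accumulate (MAC) costs activation bitwidth times parameter bitwidth bit operations; the helpers `bit_ops` and `total_bit_ops` and the per-layer MAC counts are hypothetical and not taken from AIMET.\n",
    "\n",
    "```python\n",
    "def bit_ops(macs, act_bw, param_bw):\n",
    "    # Assumed cost model: each MAC costs act_bw * param_bw bit operations\n",
    "    return macs * act_bw * param_bw\n",
    "\n",
    "# Hypothetical per-layer MAC counts, for illustration only\n",
    "layer_macs = {'conv1': 1_000_000, 'conv2': 4_000_000, 'fc': 500_000}\n",
    "\n",
    "def total_bit_ops(assignment):\n",
    "    # assignment: {layer name: (activation bitwidth, parameter bitwidth)}\n",
    "    return sum(bit_ops(layer_macs[name], a, p) for name, (a, p) in assignment.items())\n",
    "\n",
    "all_int8 = {name: (8, 8) for name in layer_macs}\n",
    "mixed = dict(all_int8, conv1=(16, 8))  # only conv1 uses 16-bit activations\n",
    "\n",
    "relative_cost = total_bit_ops(mixed) / total_bit_ops(all_int8)\n",
    "print(relative_cost)\n",
    "```\n",
    "\n",
    "Under this cost model, raising the precision of a single layer increases the total cost in proportion to that layer's share of the MACs, which is the quantity traded against accuracy on the Pareto curve."
   ]
  },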
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "08519307",
   "metadata": {},
   "outputs": [],
   "source": [
    "from aimet_common.defs import QuantizationDataType\n",
    "\n",
    "candidates = [((16, QuantizationDataType.int), (8, QuantizationDataType.int)),\n",
    "              ((8, QuantizationDataType.int), (8, QuantizationDataType.int))]\n",
    "\n",
    "allowed_accuracy_drop = 0.01\n",
    "\n",
    "results_dir = '/path/to/where/we/want/to/store/intermediate/and/final/results'"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6effa308",
   "metadata": {},
   "source": [
    "## Regular AMP\n",
    "In this we will need regular eval function as discussed above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ac064c39",
   "metadata": {},
   "outputs": [],
   "source": [
    "from aimet_common.defs import CallbackFunc\n",
    "\n",
    "eval_callback_phase1 = CallbackFunc(eval_func, 500)\n",
    "eval_callback_phase2 = CallbackFunc(eval_func, None)\n",
    "forward_pass_call_back = CallbackFunc(eval_func, 500)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0c78ecaf",
   "metadata": {},
   "source": [
    "### API Call for Regular AMP"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6f96324f",
   "metadata": {},
   "outputs": [],
   "source": [
    "from aimet_tensorflow.keras.mixed_precision import choose_mixed_precision\n",
    "from aimet_tensorflow.keras.amp.mixed_precision_algo import GreedyMixedPrecisionAlgo\n",
    "\n",
    "# Enable phase-3 (optional)\n",
    "# GreedyMixedPrecisionAlgo.ENABLE_CONVERT_OP_REDUCTION = True\n",
    "# Note: supported candidates ((8,int), (8,int)) & ((16,int), (8,int))\n",
    "\n",
    "choose_mixed_precision(sim, candidate, eval_callback_phase1, eval_callback_phase2, allowed_accuracy_drop,\n",
    "                      results_dir, True, forward_pass_call_back)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "99886f1a",
   "metadata": {},
   "source": [
    "## Fast AMP (AMP 2.0)\n",
    "\n",
    "In this method of AMP instead of using the acuracy score for the evaluation in the phase one we use the SQNR score. This speeds up the phase 1 computation of the AMP and saves some time. To use this version we need require a data loader wrapper instead of phase 1 evaluation callback. Below is a sample code for the same and also a sample call to fast AMP API."
   ]
  },
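  {
   "cell_type": "markdown",
   "id": "a7c1e0d5",
   "metadata": {},
   "source": [
    "For intuition about the SQNR score mentioned above, the following sketch computes it in decibels from an FP32 output and its quantized counterpart. The helper `sqnr_db` is hypothetical and only illustrates the standard definition (signal power divided by noise power); how AIMET computes and aggregates the score internally is not shown here.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def sqnr_db(fp32_output, quantized_output, eps=1e-12):\n",
    "    # Signal power of the reference output\n",
    "    signal_power = np.mean(np.square(fp32_output))\n",
    "    # Power of the error introduced by quantization (eps avoids division by zero)\n",
    "    noise_power = np.mean(np.square(fp32_output - quantized_output)) + eps\n",
    "    return 10.0 * np.log10(signal_power / noise_power)\n",
    "\n",
    "# Toy example: a constant signal with a constant error of 0.1\n",
    "reference = np.ones(4)\n",
    "noisy = reference + 0.1\n",
    "score = sqnr_db(reference, noisy)\n",
    "print(score)\n",
    "```\n",
    "\n",
    "A higher SQNR means the quantized output is closer to the FP32 output. Because the score needs only model outputs and no labels, the data loader wrapper below can drop the labels entirely."
   ]
  },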
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b0bb9e33",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_data_loader_wrapper(dataset_dir, batch_size, is_training=False):\n",
    "\n",
    "    def dataloader_wrapper():\n",
    "        dataloader = tf.keras.preprocessing.image_dataset_from_directory(\n",
    "            directory=dataset_dir,\n",
    "            labels='inferred',\n",
    "            label_mode='categorical',\n",
    "            batch_size=batch_size,\n",
    "            shuffle = is_training,\n",
    "            image_size=(256, 256))\n",
    "\n",
    "        return dataloader.map(lambda x, y: preprocess_input(center_crop(x)))\n",
    "\n",
    "    return dataloader_wrapper\n",
    "\n",
    "data_loader_wrapper = get_data_loader_wrapper(DATASET_DIR, BATCH_SIZE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "334dbd82",
   "metadata": {},
   "outputs": [],
   "source": [
    "from aimet_tensorflow.keras.mixed_precision import choose_fast_mixed_precision\n",
    "from aimet_tensorflow.keras.amp.mixed_precision_algo import GreedyMixedPrecisionAlgo\n",
    "\n",
    "# Enable phase-3 (optional)\n",
    "# GreedyMixedPrecisionAlgo.ENABLE_CONVERT_OP_REDUCTION = True\n",
    "# Note: supported candidates ((8,int), (8,int)) & ((16,int), (8,int))\n",
    "\n",
    "choose_fast_mixed_precision(sim, candidate, data_loader_wrapper, eval_callback_phase2, allowed_accuracy_drop,\n",
    "                      results_dir, True, forward_pass_call_back)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cc65edba",
   "metadata": {},
   "source": [
    "---\n",
    "So we have a Mixed precision model after AMP. Now the next step would be to actually take this model to target. For this purpose, we need to export the model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "177402df",
   "metadata": {},
   "outputs": [],
   "source": [
    "os.makedirs('./output/', exist_ok=True)\n",
    "sim.export(path='./output/', filename_prefix='resnet50_after_amp')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0c7bea53",
   "metadata": {},
   "source": [
    "---\n",
    "## Summary\n",
    "\n",
    "Hope this notebook was useful for you to understand how to use Automatic Mixed Precision in Keras. For more details about the parameters and configuration please refer api docs for mixed precision."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f542dda9",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
